﻿// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

//
// TypeMapInfo.cs
//
// Author:
//   Jb Evain (jbevain@novell.com)
//
// (C) 2009 Novell, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
//

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Mono.Cecil;

namespace Mono.Linker
{

    /// <summary>
    /// Lazily computes, per assembly, the relationships between virtual methods and the methods that override or
    /// implement them (base methods, overrides and default interface implementations), as well as the recursive
    /// interface list of every mapped type.
    /// </summary>
    public class TypeMapInfo
    {
        // Assemblies that EnsureProcessed has already started mapping.
        readonly HashSet<AssemblyDefinition> assemblies = new HashSet<AssemblyDefinition>();
        readonly LinkContext context;

        // Overriding/implementing method -> the methods it overrides or implements.
        protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> base_methods = new Dictionary<MethodDefinition, List<OverrideInformation>>();

        // Overridden/interface method -> the methods that override or implement it.
        protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> override_methods = new Dictionary<MethodDefinition, List<OverrideInformation>>();

        // Interface method -> default interface methods that implement it for some implementing type.
        protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> default_interface_implementations = new Dictionary<MethodDefinition, List<OverrideInformation>>();

        // Mapped type -> every interface declared on it or required by those interfaces, each paired with the
        // chain of InterfaceImplementations through which it was first reached.
        readonly Dictionary<TypeDefinition, List<(TypeReference, List<InterfaceImplementation>)>> interfaces = new();

        public TypeMapInfo(LinkContext context)
        {
            this.context = context;
        }

        /// <summary>
        /// Maps every type (including nested types) of <paramref name="assembly"/>'s main module the first time the
        /// assembly is seen. Later calls for the same assembly do nothing.
        /// </summary>
        public void EnsureProcessed(AssemblyDefinition assembly)
        {
            // The assembly is recorded before mapping, so a nested call for the same assembly returns immediately.
            if (!assemblies.Add(assembly))
                return;

            foreach (TypeDefinition type in assembly.MainModule.Types)
                MapType(type);
        }

        /// <summary>
        /// Methods that have at least one recorded override, limited to the assemblies processed so far.
        /// Reading this property does not trigger any processing.
        /// </summary>
        public ICollection<MethodDefinition> MethodsWithOverrideInformation => override_methods.Keys;

        /// <summary>
        /// Returns a list of all known methods that override <paramref name="method"/>, or null if none are recorded.
        /// The list may be incomplete if other overrides exist in assemblies that haven't been processed by TypeMapInfo yet.
        /// </summary>
        public List<OverrideInformation>? GetOverrides(MethodDefinition method)
        {
            EnsureProcessed(method.Module.Assembly);
            override_methods.TryGetValue(method, out List<OverrideInformation>? overrides);
            return overrides;
        }

        /// <summary>
        /// Returns all base methods that <paramref name="method"/> overrides or implements, or null if none are recorded.
        /// This includes the closest overridden virtual method on the base types of <paramref name="method"/>'s declaring type,
        /// methods on interfaces that the declaring type implements,
        /// and methods on interfaces implemented by a derived type when that derived type uses <paramref name="method"/> as the implementing method.
        /// The list may be incomplete if derived types in assemblies that haven't been processed yet use <paramref name="method"/> to implement an interface.
        /// </summary>
        public List<OverrideInformation>? GetBaseMethods(MethodDefinition method)
        {
            EnsureProcessed(method.Module.Assembly);
            base_methods.TryGetValue(method, out List<OverrideInformation>? bases);
            return bases;
        }

        /// <summary>
        /// Returns all default interface methods recorded as implementing <paramref name="baseMethod"/> for some type, or null if none are recorded.
        /// ImplementingType is the type that implements the interface,
        /// InterfaceImpl is the <see cref="InterfaceImplementation" /> for the interface <paramref name="baseMethod" /> is declared on, and
        /// DefaultInterfaceMethod is the method that implements <paramref name="baseMethod"/>.
        /// The list may be incomplete if implementing types live in assemblies that haven't been processed yet.
        /// </summary>
        /// <param name="baseMethod">The interface method to find default implementations for</param>
        public IEnumerable<OverrideInformation>? GetDefaultInterfaceImplementations(MethodDefinition baseMethod)
        {
            // Like the other accessors, make sure at least the declaring assembly has been mapped.
            EnsureProcessed(baseMethod.Module.Assembly);
            default_interface_implementations.TryGetValue(baseMethod, out var ret);
            return ret;
        }

        /// <summary>
        /// Records that <paramref name="method"/> overrides or implements <paramref name="base"/>.
        /// </summary>
        public void AddBaseMethod(MethodDefinition method, MethodDefinition @base, InterfaceImplementor? interfaceImplementor)
        {
            base_methods.AddToList(method, new OverrideInformation(@base, method, interfaceImplementor));
        }

        /// <summary>
        /// Records that <paramref name="override"/> overrides or implements <paramref name="base"/>.
        /// </summary>
        public void AddOverride(MethodDefinition @base, MethodDefinition @override, InterfaceImplementor? interfaceImplementor = null)
        {
            override_methods.AddToList(@base, new OverrideInformation(@base, @override, interfaceImplementor));
        }

        /// <summary>
        /// Records that <paramref name="defaultImplementationMethod"/> is a default implementation of the interface
        /// method <paramref name="base"/> for the type described by <paramref name="interfaceImplementor"/>.
        /// </summary>
        public void AddDefaultInterfaceImplementation(MethodDefinition @base, InterfaceImplementor interfaceImplementor, MethodDefinition defaultImplementationMethod)
        {
            Debug.Assert(@base.DeclaringType.IsInterface);
            default_interface_implementations.AddToList(@base, new OverrideInformation(@base, defaultImplementationMethod, interfaceImplementor));
        }

        /// <summary>
        /// Maps virtual methods, interface implementations and the recursive interface list of
        /// <paramref name="type"/>, then recurses into its nested types.
        /// </summary>
        protected virtual void MapType(TypeDefinition type)
        {
            MapVirtualMethods(type);
            MapInterfaceMethodsInTypeHierarchy(type);
            interfaces[type] = GetRecursiveInterfaceImplementations(type);

            if (!type.HasNestedTypes)
                return;

            foreach (var nested in type.NestedTypes)
                MapType(nested);
        }

        /// <summary>
        /// Returns the interfaces declared on <paramref name="type"/> or required by those interfaces, each with the
        /// chain of InterfaceImplementations that first reaches it. Returns null if no entry was recorded for the type
        /// after its assembly was processed.
        /// </summary>
        internal List<(TypeReference InterfaceType, List<InterfaceImplementation> ImplementationChain)>? GetRecursiveInterfaces(TypeDefinition type)
        {
            EnsureProcessed(type.Module.Assembly);
            return interfaces.TryGetValue(type, out var value) ? value : null;
        }

        /// <summary>
        /// Collects every interface declared on <paramref name="type"/> or transitively required by those interfaces.
        /// Each interface appears once, with the first implementation chain found: direct interfaces of a level are
        /// added before recursing into them.
        /// </summary>
        List<(TypeReference InterfaceType, List<InterfaceImplementation> ImplementationChain)> GetRecursiveInterfaceImplementations(TypeDefinition type)
        {
            List<(TypeReference, List<InterfaceImplementation>)> firstImplementationChain = new();

            AddRecursiveInterfaces(type, [], firstImplementationChain, context);
            Debug.Assert(firstImplementationChain.All(kvp => context.Resolve(kvp.Item1) == context.Resolve(kvp.Item2.Last().InterfaceType)));

            return firstImplementationChain;

            static void AddRecursiveInterfaces(TypeReference typeRef, IEnumerable<InterfaceImplementation> pathToType, List<(TypeReference, List<InterfaceImplementation>)> firstImplementationChain, LinkContext context)
            {
                var type = context.TryResolve(typeRef);
                // If we can't resolve the interface type we can't find recursive interfaces
                if (type is null)
                    return;

                // Get all explicit interfaces of this type
                foreach (var iface in type.Interfaces)
                {
                    var interfaceType = iface.InterfaceType.InflateFrom(typeRef as IGenericInstance);
                    if (!firstImplementationChain.Any(i => TypeReferenceEqualityComparer.AreEqual(i.Item1, interfaceType, context)))
                    {
                        firstImplementationChain.Add((interfaceType, pathToType.Append(iface).ToList()));
                    }
                }

                // Recursive interfaces after all direct interfaces to preserve Inherit/Implement tree order
                foreach (var iface in type.Interfaces)
                {
                    var ifaceDirectlyOnType = iface.InterfaceType.InflateFrom(typeRef as IGenericInstance);
                    AddRecursiveInterfaces(ifaceDirectlyOnType, pathToType.Append(iface), firstImplementationChain, context);
                }
            }
        }

        /// <summary>
        /// For each virtual, non-final method of each interface implemented by <paramref name="type"/>, records the
        /// implementing method. The lookup tries a name/signature match on the type, then on its base types (instance
        /// methods only), and finally searches for default interface implementations.
        /// </summary>
        void MapInterfaceMethodsInTypeHierarchy(TypeDefinition type)
        {
            if (!type.HasInterfaces)
                return;

            // Foreach interface and for each newslot virtual method on the interface, try
            // to find the method implementation and record it.
            foreach (var interfaceImpl in type.GetInflatedInterfaces(context))
            {
                foreach (MethodReference interfaceMethod in interfaceImpl.InflatedInterface.GetMethods(context))
                {
                    MethodDefinition? resolvedInterfaceMethod = context.TryResolve(interfaceMethod);
                    if (resolvedInterfaceMethod == null)
                        continue;

                    // TODO-NICE: if the interface method is implemented explicitly (with an override),
                    // we shouldn't need to run the below logic. This results in ILLink potentially
                    // keeping more methods than needed.

                    if (!resolvedInterfaceMethod.IsVirtual
                        || resolvedInterfaceMethod.IsFinal)
                        continue;

                    // Static methods on interfaces must be implemented only via explicit method-impl record
                    // not by a signature match. So there's no point in running this logic for static methods.
                    if (!resolvedInterfaceMethod.IsStatic)
                    {
                        // Try to find an implementation with a name/sig match on the current type
                        MethodDefinition? exactMatchOnType = TryMatchMethod(type, interfaceMethod);
                        if (exactMatchOnType != null)
                        {
                            AnnotateMethods(resolvedInterfaceMethod, exactMatchOnType, new(type, interfaceImpl.OriginalImpl, resolvedInterfaceMethod.DeclaringType, context));
                            continue;
                        }

                        // Next try to find an implementation with a name/sig match in the base hierarchy
                        var @base = GetBaseMethodInTypeHierarchy(type, interfaceMethod);
                        if (@base != null)
                        {
                            AnnotateMethods(resolvedInterfaceMethod, @base, new(type, interfaceImpl.OriginalImpl, resolvedInterfaceMethod.DeclaringType, context));
                            continue;
                        }
                    }

                    // Look for a default implementation last.
                    FindAndAddDefaultInterfaceImplementations(type, type, resolvedInterfaceMethod, interfaceImpl.OriginalImpl);
                }
            }
        }

        /// <summary>
        /// Records base/override relationships for the virtual methods of <paramref name="type"/> and for the
        /// explicit overrides (.override records) of its virtual or static methods.
        /// </summary>
        void MapVirtualMethods(TypeDefinition type)
        {
            if (!type.HasMethods)
                return;

            foreach (MethodDefinition method in type.Methods)
            {
                // We do not proceed unless a method is virtual or is static
                // A static method with a .override could be implementing a static interface method
                if (!(method.IsStatic || method.IsVirtual))
                    continue;

                if (method.IsVirtual)
                    MapVirtualMethod(method);

                if (method.HasOverrides)
                    MapOverrides(method);
            }
        }

        /// <summary>
        /// Records the closest name/signature-matching virtual method on the base types of
        /// <paramref name="method"/>'s declaring type, if any, as its base method.
        /// </summary>
        void MapVirtualMethod(MethodDefinition method)
        {
            MethodDefinition? @base = GetBaseMethodInTypeHierarchy(method);
            if (@base == null)
                return;

            Debug.Assert(!@base.DeclaringType.IsInterface);

            AnnotateMethods(@base, method);
        }

        /// <summary>
        /// Records every resolvable method named in <paramref name="method"/>'s explicit overrides as one of its base
        /// methods, attaching an <see cref="InterfaceImplementor"/> when the overridden method is declared on an interface.
        /// </summary>
        void MapOverrides(MethodDefinition method)
        {
            foreach (MethodReference baseMethodRef in method.Overrides)
            {
                MethodDefinition? baseMethod = context.TryResolve(baseMethodRef);
                if (baseMethod == null)
                    continue;

                var interfaceImplementor = baseMethod.DeclaringType.IsInterface
                    ? InterfaceImplementor.Create(method.DeclaringType, baseMethod.DeclaringType, context)
                    : null;
                AnnotateMethods(baseMethod, method, interfaceImplementor);
            }
        }

        /// <summary>
        /// Records the relationship in both directions: <paramref name="base"/> as a base of
        /// <paramref name="override"/>, and <paramref name="override"/> as an override of <paramref name="base"/>.
        /// </summary>
        void AnnotateMethods(MethodDefinition @base, MethodDefinition @override, InterfaceImplementor? interfaceImplementor = null)
        {
            AddBaseMethod(@override, @base, interfaceImplementor);
            AddOverride(@base, @override, interfaceImplementor);
        }

        MethodDefinition? GetBaseMethodInTypeHierarchy(MethodDefinition method)
        {
            return GetBaseMethodInTypeHierarchy(method.DeclaringType, method);
        }

        /// <summary>
        /// Walks the (inflated) base types of <paramref name="type"/>, starting from its direct base, and returns the
        /// first virtual method that matches <paramref name="method"/> by name and signature, or null.
        /// </summary>
        MethodDefinition? GetBaseMethodInTypeHierarchy(TypeDefinition type, MethodReference method)
        {
            TypeReference? @base = GetInflatedBaseType(type);
            while (@base != null)
            {
                MethodDefinition? base_method = TryMatchMethod(@base, method);
                if (base_method != null)
                    return base_method;

                @base = GetInflatedBaseType(@base);
            }

            return null;
        }

        /// <summary>
        /// Returns the base type of <paramref name="type"/>, with generic arguments substituted when
        /// <paramref name="type"/> is a generic instance. Wrapper types (sentinel, pinned, required/optional
        /// modifiers) are unwrapped first. Returns null for generic parameters, by-refs, pointers and unresolvable types.
        /// </summary>
        TypeReference? GetInflatedBaseType(TypeReference type)
        {
            if (type is null)
                return null;

            if (type.IsGenericParameter || type.IsByReference || type.IsPointer)
                return null;

            if (type is SentinelType sentinelType)
                return GetInflatedBaseType(sentinelType.ElementType);

            if (type is PinnedType pinnedType)
                return GetInflatedBaseType(pinnedType.ElementType);

            // Covers both modreq and modopt wrappers.
            if (type is IModifierType modifierType)
                return GetInflatedBaseType(modifierType.ElementType);

            if (type is GenericInstanceType genericInstance)
            {
                var baseType = context.TryResolve(type)?.BaseType;

                if (baseType is GenericInstanceType)
                    return TypeReferenceExtensions.InflateGenericType(genericInstance, baseType);

                return baseType;
            }

            return context.TryResolve(type)?.BaseType;
        }

        /// <summary>
        /// Records the default implementations of <paramref name="interfaceMethodToBeImplemented"/> that are reachable
        /// from the interfaces of <paramref name="typeThatMayHaveDIM"/>, recursing into the interfaces those interfaces
        /// require when a given interface provides none.
        /// More than one implementation may be recorded, which covers the diamond case (more than one most specific
        /// implementation of the given interface method). ILLink needs to preserve all the implementations so that
        /// the proper exception can be thrown at runtime.
        /// </summary>
        /// <param name="typeThatImplementsInterface">The type that implements (directly or via a base interface) the declaring interface of <paramref name="interfaceMethodToBeImplemented"/>; it is recorded as the implementing type.</param>
        /// <param name="typeThatMayHaveDIM">The type or interface whose interface list is searched. Initially the same as <paramref name="typeThatImplementsInterface"/>; during recursion, each interface being inspected.</param>
        /// <param name="interfaceMethodToBeImplemented">The interface method to find a default implementation for.</param>
        /// <param name="originalInterfaceImpl">
        /// The InterfaceImplementation on <paramref name="typeThatImplementsInterface"/> associated with the declaring
        /// interface of <paramref name="interfaceMethodToBeImplemented"/>; passed through unchanged while recursing.
        /// </param>
        void FindAndAddDefaultInterfaceImplementations(TypeDefinition typeThatImplementsInterface, TypeDefinition typeThatMayHaveDIM, MethodDefinition interfaceMethodToBeImplemented, InterfaceImplementation originalInterfaceImpl)
        {
            // Go over all interfaces, trying to find a method that is an explicit MethodImpl of the
            // interface method in question.
            foreach (var interfaceImpl in typeThatMayHaveDIM.Interfaces)
            {
                var potentialImplInterface = context.TryResolve(interfaceImpl.InterfaceType);
                if (potentialImplInterface == null)
                    continue;

                bool foundImpl = false;

                foreach (var potentialImplMethod in potentialImplInterface.Methods)
                {
                    // A match is either the (non-abstract) interface method itself, or a method whose
                    // explicit overrides resolve to it.
                    bool isImplementation =
                        (potentialImplMethod == interfaceMethodToBeImplemented && !potentialImplMethod.IsAbstract)
                        || (potentialImplMethod.HasOverrides
                            && potentialImplMethod.Overrides.Any(baseMethod => context.TryResolve(baseMethod) == interfaceMethodToBeImplemented));

                    if (isImplementation)
                    {
                        AddDefaultInterfaceImplementation(interfaceMethodToBeImplemented, new(typeThatImplementsInterface, originalInterfaceImpl, interfaceMethodToBeImplemented.DeclaringType, context), potentialImplMethod);
                        foundImpl = true;
                        break;
                    }
                }

                // We haven't found a MethodImpl on the current interface, but one of the interfaces
                // this interface requires could still provide it.
                if (!foundImpl)
                {
                    FindAndAddDefaultInterfaceImplementations(typeThatImplementsInterface, potentialImplInterface, interfaceMethodToBeImplemented, originalInterfaceImpl);
                }
            }
        }

        /// <summary>
        /// Returns the first virtual method declared on <paramref name="type"/> that matches <paramref name="method"/>
        /// by name and signature, or null.
        /// </summary>
        MethodDefinition? TryMatchMethod(TypeReference type, MethodReference method)
        {
            foreach (var candidate in type.GetMethods(context))
            {
                var md = context.TryResolve(candidate);
                if (md?.IsVirtual != true)
                    continue;

                if (MethodMatch(candidate, method))
                    return md;
            }

            return null;
        }

        /// <summary>
        /// Compares two methods by name, generic arity, return type and parameter types.
        /// </summary>
        [SuppressMessage("ApiDesign", "RS0030:Do not used banned APIs", Justification = "It's best to leave working code alone.")]
        static bool MethodMatch(MethodReference candidate, MethodReference method)
        {
            if (candidate.HasParameters != method.HasMetadataParameters())
                return false;

            if (candidate.Name != method.Name)
                return false;

            if (candidate.HasGenericParameters != method.HasGenericParameters)
                return false;

            // Generic arity is part of the method signature, so it must be compared before the
            // parameterless early-out below; otherwise M<T>() and M<T, U>() would be considered a match.
            if (candidate.HasGenericParameters && candidate.GenericParameters.Count != method.GenericParameters.Count)
                return false;

            // we need to track what the generic parameter represent - as we cannot allow it to
            // differ between the return type or any parameter
            if (!TypeMatch(candidate.GetReturnType(), method.GetReturnType()))
                return false;

            if (!candidate.HasMetadataParameters())
                return true;

            var cp = candidate.Parameters;
            var mp = method.Parameters;
            if (cp.Count != mp.Count)
                return false;

            for (int i = 0; i < cp.Count; i++)
            {
                if (!TypeMatch(candidate.GetInflatedParameterType(i), method.GetInflatedParameterType(i)))
                    return false;
            }

            return true;
        }

        static bool TypeMatch(IModifierType a, IModifierType b)
        {
            if (!TypeMatch(a.ModifierType, b.ModifierType))
                return false;

            return TypeMatch(a.ElementType, b.ElementType);
        }

        /// <summary>
        /// Compares two type specifications. Callers guarantee that both have the same runtime type.
        /// </summary>
        static bool TypeMatch(TypeSpecification a, TypeSpecification b)
        {
            if (a is GenericInstanceType gita)
                return TypeMatch(gita, (GenericInstanceType)b);

            if (a is IModifierType mta)
                return TypeMatch(mta, (IModifierType)b);

            if (a is FunctionPointerType fpta)
                return TypeMatch(fpta, (FunctionPointerType)b);

            // Array rank is part of the type identity: int[] and int[,] must not match.
            if (a is ArrayType arrayA && arrayA.Rank != ((ArrayType)b).Rank)
                return false;

            return TypeMatch(a.ElementType, b.ElementType);
        }

        static bool TypeMatch(GenericInstanceType a, GenericInstanceType b)
        {
            if (!TypeMatch(a.ElementType, b.ElementType))
                return false;

            if (a.HasGenericArguments != b.HasGenericArguments)
                return false;

            if (!a.HasGenericArguments)
                return true;

            var gaa = a.GenericArguments;
            var gab = b.GenericArguments;
            if (gaa.Count != gab.Count)
                return false;

            for (int i = 0; i < gaa.Count; i++)
            {
                if (!TypeMatch(gaa[i], gab[i]))
                    return false;
            }

            return true;
        }

        /// <summary>
        /// Generic parameters match when they have the same position and the same owner kind (type vs method).
        /// </summary>
        static bool TypeMatch(GenericParameter a, GenericParameter b)
        {
            if (a.Position != b.Position)
                return false;

            if (a.Type != b.Type)
                return false;

            return true;
        }

        static bool TypeMatch(FunctionPointerType a, FunctionPointerType b)
        {
            if (a.HasParameters != b.HasParameters)
                return false;

            if (a.CallingConvention != b.CallingConvention)
                return false;

            // we need to track what the generic parameter represent - as we cannot allow it to
            // differ between the return type or any parameter
            if (a.ReturnType is not TypeReference aReturnType ||
                b.ReturnType is not TypeReference bReturnType ||
                !TypeMatch(aReturnType, bReturnType))
                return false;

            if (!a.HasParameters)
                return true;

            var ap = a.Parameters;
            var bp = b.Parameters;
            if (ap.Count != bp.Count)
                return false;

            for (int i = 0; i < ap.Count; i++)
            {
                if (ap[i].ParameterType is not TypeReference aParameterType ||
                    bp[i].ParameterType is not TypeReference bParameterType ||
                    !TypeMatch(aParameterType, bParameterType))
                    return false;
            }

            return true;
        }

        /// <summary>
        /// Structural type comparison: type specifications are compared recursively, generic parameters by
        /// position and owner kind, and everything else by full name.
        /// </summary>
        static bool TypeMatch(TypeReference a, TypeReference b)
        {
            if (a is TypeSpecification || b is TypeSpecification)
            {
                if (a.GetType() != b.GetType())
                    return false;

                return TypeMatch((TypeSpecification)a, (TypeSpecification)b);
            }

            if (a is GenericParameter genericParameterA && b is GenericParameter genericParameterB)
                return TypeMatch(genericParameterA, genericParameterB);

            return a.FullName == b.FullName;
        }
    }
}
